import torch

# Check whether the Huawei Ascend NPU is supported
def env_test():
    """Print whether the CUDA and Huawei Ascend NPU backends are available.

    ``torch_npu`` is an optional dependency and is not imported at module
    level, so it is imported lazily here. When the package is not
    installed the NPU is simply reported as unavailable instead of the
    call failing.
    """
    print("CUDA is available: ", torch.cuda.is_available())
    try:
        # Local import: keeps this module importable on machines without
        # the Ascend plugin.
        import torch_npu
    except ImportError:
        npu_available = False
    else:
        npu_available = torch_npu.npu.is_available()
    print("Ascend NPU is available", npu_available)

def get_device(is_check_npu: bool):
    """Select the compute device to run on.

    Preference order is Ascend NPU (only when requested), then CUDA, then
    CPU.

    Args:
        is_check_npu: When truthy, probe for a Huawei Ascend NPU first.
            This needs the optional ``torch_npu`` package; if it is not
            installed the NPU is treated as unavailable.

    Returns:
        ``torch.device('npu:0')`` when an NPU was requested and is
        available; otherwise the string ``'cuda'`` if CUDA is available,
        else the string ``'cpu'``.
    """
    if is_check_npu:
        try:
            # Imported lazily so that hosts without the Ascend plugin can
            # still use the CUDA/CPU paths.
            import torch_npu
        except ImportError:
            torch_npu = None
        if torch_npu is not None and torch_npu.npu.is_available():
            # Target the NPU itself; a CUDA device may not exist on an
            # Ascend host.
            return torch.device('npu:0')
    return 'cuda' if torch.cuda.is_available() else 'cpu'
